"""
Spatial maps of differences between 
HYCOM & HRET during SWOT SCIENCE orbit

B. Yadidya
Feb 24, 2026
"""

import sys
sys.path.append('.')
from common_imports import *

# Region-averaged values per tidal constituent for HYCOM. Each array has four
# entries, in the same order as the `regions` list used for the bar plots
# (Global, Pacific, Indian, Atlantic).
# NOTE(review): units are presumably cm^2 (the bar plots multiply by 100 and
# label the axis mm^2) -- confirm against the script that produced them.
hycom_avgs_tides = {
    constituent: np.array(values)
    for constituent, values in [
        ('M2', [0.14390679, 0.18910384, 0.14470802, 0.03543389]),
        ('S2', [0.0403017, 0.04215384, 0.04967995, 0.00610094]),
        ('N2', [0.01254888, 0.01407153, 0.0090684, 0.00236676]),
        ('K1', [0.03498775, 0.03580149, 0.02578041, 0.0020535]),
        ('O1', [0.02224772, 0.02307632, 0.01217359, 0.00149862]),
    ]
}

# Same layout as above, for HRET. Note that the last entry of K1 and O1 is
# slightly negative.
hret_avgs_tides = {
    constituent: np.array(values)
    for constituent, values in [
        ('M2', [0.13160919, 0.18800078, 0.11823646, 0.03689674]),
        ('S2', [0.01398838, 0.01894715, 0.01680818, 0.00217958]),
        ('N2', [0.00157728, 0.0024333, 0.00080108, 0.00030441]),
        ('K1', [0.02639873, 0.03930837, 0.01977189, -0.00066817]),
        ('O1', [0.01484721, 0.02404912, 0.00479519, -4.7829247e-05]),
    ]
}

# Totals for HYCOM and HRET (four entries each, same ordering as above).
total_hycom_avgs, total_hret_avgs = (
    np.array(values)
    for values in (
        [0.23652243, 0.31060186, 0.24355097, 0.06327695],
        [0.20459338, 0.30075043, 0.15984338, 0.03503372],
    )
)


# Directory holding the regridded NetCDF inputs. Named `data_dir` (not `dir`)
# so that the builtin `dir()` is not shadowed for the rest of the script.
data_dir = "data/"

hret_ds = xr.open_dataset(data_dir + 'regridded_0.25deg_HRET22_VR_v2.0.1.nc')
# The HYCOM longitude coordinate is overwritten with HRET's values.
# NOTE(review): presumably so both grids carry identical coordinates when the
# fields are subtracted below -- confirm.
hycom_ds = xr.open_dataset(
    data_dir + 'regridded_0.25deg_CONSTITUENTS_VR_v2.0.1.nc'
).assign_coords(longitude=hret_ds.longitude.values)

# HYCOM minus HRET for each of the five constituents, one DataArray per name.
m2_diff, s2_diff, n2_diff, k1_diff, o1_diff = (
    hycom_ds.var_red_constituents.sel(constituent=name)
    - hret_ds.hret_ms_var.sel(constituent=name)
    for name in ('M2', 'S2', 'N2', 'K1', 'O1')
)

"""
# I used the 4km regridded datasets for the plots shown in the manuscript.
# However, for storage reasons, 0.25 deg regridded datasets are uploaded on Dataverse.
hret_ds = xr.open_dataset(data_dir + 'HRET22_4km_VR_v2.0.1.nc')
hycom_ds = xr.open_dataset(data_dir + 'regridded_4km_CONSTITUENTS_VR_v2.0.1.nc').assign_coords(longitude = hret_ds.longitude.values)

m2_diff = coarsen_data((hycom_ds.var_red_constituents.sel(constituent = 'M2') - hret_ds.hret_ms_var.sel(constituent = 'M2')), .2)
s2_diff = coarsen_data((hycom_ds.var_red_constituents.sel(constituent = 'S2') - hret_ds.hret_ms_var.sel(constituent = 'S2')), .2)
n2_diff = coarsen_data((hycom_ds.var_red_constituents.sel(constituent = 'N2') - hret_ds.hret_ms_var.sel(constituent = 'N2')), .2)
k1_diff = coarsen_data((hycom_ds.var_red_constituents.sel(constituent = 'K1') - hret_ds.hret_ms_var.sel(constituent = 'K1')), .2)
o1_diff = coarsen_data((hycom_ds.var_red_constituents.sel(constituent = 'O1') - hret_ds.hret_ms_var.sel(constituent = 'O1')), .2)
"""

# Bar colours: blue is used for the HRET bars, the (vibrant, colorblind-safe)
# orange for the HYCOM bars.
blue_color, orange_red_color = '#4477AA', '#EE7733'

# Region names, used as x tick labels of the bottom bar chart.
regions = [
    'Global',
    'Pacific',
    'Indian',
    'Atlantic',
]

# Figure size is specified in centimetres and converted to inches
# (1 in = 2.54 cm) because matplotlib's `figsize` expects inches.
fig_width_cm, fig_height_max_cm = 18.4, 19.0
fig_width_in, fig_height_in = (
    length_cm / 2.54 for length_cm in (fig_width_cm, fig_height_max_cm)
)

# Figure with a 5 x 2 grid: column 0 holds the maps (wider), column 1 the
# bar charts; one row per tidal constituent.
fig = plt.figure(
    figsize=(fig_width_in, fig_height_in),
    constrained_layout=True,
)
gs = GridSpec(
    nrows=5,
    ncols=2,
    figure=fig,
    width_ratios=[2.2, 1.4],
    height_ratios=[1, 1, 1, 1, 1],
)
# All map axes are created first, then all bar axes.
map_axes = [fig.add_subplot(gs[row, 0]) for row in range(5)]
bar_axes = [fig.add_subplot(gs[row, 1]) for row in range(5)]

# Keyword arguments shared by every contourf map; colorbars are added by
# hand later, so xarray's automatic one is disabled.
map_kwargs = {
    'x': 'longitude',
    'cmap': colormaps.prinsenvlag_r,
    'add_colorbar': False,
}

# One (difference field, contour levels, row label) tuple per map.
# Levels are 21 evenly spaced values over [-half_range, +half_range].
map_plot_configs = [
    (diff_field, np.linspace(-half_range, half_range, 21), row_label)
    for diff_field, half_range, row_label in (
        (m2_diff, 1, 'M$_2$'),
        (s2_diff, 1, 'S$_2$'),
        (n2_diff, 0.5, 'N$_2$'),
        (k1_diff, 0.5, 'K$_1$'),
        (o1_diff, 0.5, 'O$_1$'),
    )
]

# Constituent keys, in the row order of the bar charts (top to bottom).
bar_tide_names = ['M2', 'S2', 'N2', 'K1', 'O1']

# Text style for the rotated constituent label drawn left of each map.
label_kwargs = {
    'fontsize': 9,
    'ha': 'right',
    'va': 'center',
    'rotation': 90,
    'fontweight': 'bold',
}

# Text style for the panel letters: bold, centred, on an opaque white
# rounded box without an edge.
panel_label_props = {
    'fontsize': 9,
    'fontweight': 'bold',
    'ha': 'center',
    'va': 'center',
    'bbox': {
        'boxstyle': 'round, pad=0.1',
        'facecolor': 'white',
        'edgecolor': 'none',
        'alpha': 1,
    },
}

# --- Maps ---
# Left column: one filled-contour map per constituent, each with a small
# horizontal colorbar inset near the bottom of the panel.
for panel_idx, (map_ax, (diff_field, contour_levels, row_label)) in enumerate(
        zip(map_axes, map_plot_configs)):
    filled = diff_field.plot.contourf(ax=map_ax, levels=contour_levels, **map_kwargs)

    # Custom colorbar drawn in an inset axes anchored to the map axes.
    cbar_ax = inset_axes(
        map_ax,
        width="25%",
        height="5%",
        loc='lower center',
        bbox_to_anchor=(0.09, 0.0, 1, 1),
        bbox_transform=map_ax.transAxes,
    )
    colorbar = plt.colorbar(filled, cax=cbar_ax, orientation='horizontal')
    # Panels whose levels start at -1 (M2 and S2) get ticks at -1/0/1;
    # all others get ticks at -0.5/0/0.5.
    colorbar.set_ticks([-1, 0, 1] if contour_levels[0] == -1 else [-0.5, 0, 0.5])
    # Hide the tick marks themselves (zero length) but keep 8 pt labels.
    colorbar.ax.tick_params(which='both', size=0, labelsize=8)
    colorbar.ax.minorticks_off()
    for tick_label in colorbar.ax.get_xticklabels():
        tick_label.set_fontweight('normal')
    # Ticks and the unit label both go above the colorbar.
    colorbar.ax.xaxis.set_ticks_position('top')
    colorbar.ax.set_xlabel('cm$^2$', fontsize=8)
    colorbar.ax.xaxis.set_label_position('top')

    plot_global_map_on_axes_basemap(map_ax)
    map_ax.set_aspect('equal')
    # Remove the axis labels and title that xarray adds automatically.
    map_ax.set(xlabel='', ylabel='', title='')
    map_ax.set_ylim(-60, 60)

    # Panel letter: maps take every other letter starting at 'A'
    # (A, C, E, G, I); the bar charts take the letters in between.
    panel_letter = chr(ord('A') + 2 * panel_idx)
    map_ax.text(0.02, .92, panel_letter, transform=map_ax.transAxes,
                **panel_label_props, color='k')
    # Rotated constituent name just outside the left edge of the map.
    map_ax.text(-0.01, 0.5, row_label, transform=map_ax.transAxes, **label_kwargs)

# --- Bar Plots ---
# Right column: for each constituent, paired HRET / HYCOM bars per region.
bar_width = 0.4
x = np.arange(len(regions))
last_row = len(bar_tide_names) - 1

for row, (bar_ax, tide_name) in enumerate(zip(bar_axes, bar_tide_names)):
    # Averages are scaled by 100 for display; the y axis is labelled mm^2.
    hret_vals = hret_avgs_tides[tide_name] * 100
    hycom_vals = hycom_avgs_tides[tide_name] * 100

    # HRET bars sit left of each tick, HYCOM bars right of it.
    hret_bars = bar_ax.bar(x - bar_width/2, hret_vals, bar_width,
                           label='HRET', color=blue_color, alpha=0.9)
    hycom_bars = bar_ax.bar(x + bar_width/2, hycom_vals, bar_width,
                            label='HYCOM', color=orange_red_color, alpha=0.9)

    # Annotate every bar with its value (one decimal), centred horizontally
    # on the bar and placed at the bar's height.
    for bars, values in ((hret_bars, hret_vals), (hycom_bars, hycom_vals)):
        for rect, value in zip(bars, values):
            bar_ax.text(rect.get_x() + rect.get_width()/2, rect.get_height(),
                        f'{value:.1f}',
                        va='bottom', ha='center', fontsize=7, color='black')

    bar_ax.set_xticks(x)
    # Panel letter: bar charts take every other letter starting at 'B'
    # (B, D, F, H, J), interleaving with the map panels.
    bar_ax.text(0.04, .92, chr(ord('B') + 2 * row),
                transform=bar_ax.transAxes, **panel_label_props)

    # Put the y axis (ticks, tick labels and axis label) on the right side.
    bar_ax.yaxis.set_label_position("right")
    bar_ax.yaxis.tick_right()
    bar_ax.tick_params(direction='in')
    bar_ax.tick_params(axis='y', labelsize=8, labelright=True, right=True,
                       left=False, labelleft=False)
    bar_ax.set_ylabel('mm$^2$', fontsize=8)

    # Only the bottom panel shows region names and the legend.
    if row != last_row:
        bar_ax.set_xticklabels([])
    else:
        bar_ax.set_xticklabels(regions, fontsize=9)
        bar_ax.legend(fontsize=9)

# Set axis limits for bar plots (adjust as needed); order: M2, S2, N2, K1, O1.
# K1 and O1 start slightly below zero.
for bar_ax, y_limits in zip(bar_axes, [(0, 22), (0, 7), (0, 1.8), (-.3, 6), (-.3, 4)]):
    bar_ax.set_ylim(*y_limits)

plt.savefig('Fig3_HYCOM_HRET_comparison_individual_tides.png', dpi=500, bbox_inches='tight')